{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "226279a0",
   "metadata": {},
   "source": [
    "# ONNX Runtime Deployment: ImageNet-1000, Predicting a Single Image\n",
    "\n",
    "Use the ONNX Runtime inference engine to load a model file in ONNX format and run prediction on a single image file.\n",
    "\n",
    "同济子豪兄 https://space.bilibili.com/1900783\n",
    "\n",
    "2022-8-22 2023-5-8 2023-8-1"
   ]
  },
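  {
   "cell_type": "markdown",
   "id": "e3a519f3",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Use Case\n",
    "\n",
    "Run the code below on the hardware you are deploying to (local PC, phone, embedded development board, Raspberry Pi, Jetson Nano, or server).\n",
    "\n",
    "Copy the `.onnx` model file to the target hardware and install ONNX Runtime there; the few lines of code below are then enough to run the model. As written, this notebook also imports torch, torchvision, Pillow, and pandas for preprocessing and for parsing the results."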
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba714368",
   "metadata": {},
   "source": [
    "## Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1d319ddf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d1fb9ec",
   "metadata": {},
   "source": [
    "## Load the ONNX model and create an ONNX Runtime inference session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "098a9aa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "ort_session = onnxruntime.InferenceSession('resnet18_imagenet.onnx')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6758cd81",
   "metadata": {},
   "source": [
    "## Build a random input and get the output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d0e7de63",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(1, 3, 256, 256).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d44c1c51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 3, 256, 256)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4d2acb36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ONNX Runtime input: a dict mapping input tensor names to NumPy arrays\n",
    "ort_inputs = {'input': x}\n",
    "\n",
    "# ONNX Runtime output: run() returns a list, so take the first requested output\n",
    "ort_output = ort_session.run(['output'], ort_inputs)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4d64bca",
   "metadata": {},
   "source": [
    "Note: the input and output tensor names must match the input and output names set in `torch.onnx.export`."
   ]
  },
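  {
   "cell_type": "markdown",
   "id": "3fa1c0de",
   "metadata": {},
   "source": [
    "If the names used at export time are not known, they can be read back from the loaded session rather than hard-coded. The cell below is a minimal sketch that assumes the model has exactly one input and one output; `input_name` and `output_name` are illustrative variable names that are not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b2e9d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: read the tensor names and shapes stored in the ONNX model\n",
    "input_meta = ort_session.get_inputs()[0]    # assumes a single model input\n",
    "output_meta = ort_session.get_outputs()[0]  # assumes a single model output\n",
    "\n",
    "input_name = input_meta.name\n",
    "output_name = output_meta.name\n",
    "print(input_name, input_meta.shape)\n",
    "print(output_name, output_meta.shape)\n",
    "\n",
    "# The queried names can then replace the hard-coded 'input' and 'output'\n",
    "ort_output_by_name = ort_session.run([output_name], {input_name: x})[0]"
   ]
  },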
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4e69b7d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 1000)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ort_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b4b2e005",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ort_output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98759567",
   "metadata": {},
   "source": [
    "## Load a real test image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aa446d8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "img_path = 'banana1.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ffffe346",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the image with Pillow\n",
    "from PIL import Image\n",
    "img_pil = Image.open(img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7c791d26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# img_pil"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4601e5b",
   "metadata": {},
   "source": [
    "## Preprocessing function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a79ee478",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: \n",
      "  warn(f\"Failed to load image Python extension: {e}\")\n"
     ]
    }
   ],
   "source": [
    "from torchvision import transforms\n",
    "\n",
    "# Test-set image preprocessing (RCTN): resize, center crop, convert to Tensor, normalize\n",
    "test_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                     transforms.CenterCrop(256),\n",
    "                                     transforms.ToTensor(),\n",
    "                                     transforms.Normalize(\n",
    "                                         mean=[0.485, 0.456, 0.406], \n",
    "                                         std=[0.229, 0.224, 0.225])\n",
    "                                    ])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5aae48d",
   "metadata": {},
   "source": [
    "## Run preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cf4de7b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_img = test_transform(img_pil)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2513ea9b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 256, 256])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_img.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "70be81bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_tensor = input_img.unsqueeze(0).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "efbfaa2f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 3, 256, 256)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_tensor.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1259c337",
   "metadata": {},
   "source": [
    "## Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "20491934",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ONNX Runtime input\n",
    "ort_inputs = {'input': input_tensor}\n",
    "\n",
    "# ONNX Runtime output, converted to a torch tensor for the softmax step\n",
    "pred_logits = ort_session.run(['output'], ort_inputs)[0]\n",
    "pred_logits = torch.tensor(pred_logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "536d3cc1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1000])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6f3ba4d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply softmax to the logits to get confidence probabilities\n",
    "pred_softmax = F.softmax(pred_logits, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9b53f3f0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1000])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_softmax.shape"
   ]
  },
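  {
   "cell_type": "markdown",
   "id": "c84f1a6b",
   "metadata": {},
   "source": [
    "On deployment hardware where installing PyTorch is impractical, the softmax step can be done with NumPy alone, since `ort_session.run` already returns NumPy arrays. The cell below is a minimal sketch of that alternative, not part of the original pipeline; `softmax_np`, `logits_np`, and `probs_np` are illustrative names introduced here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0e72f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax_np(logits, axis=1):\n",
    "    # Subtract the per-row maximum before exponentiating for numerical stability\n",
    "    shifted = logits - np.max(logits, axis=axis, keepdims=True)\n",
    "    exp_scores = np.exp(shifted)\n",
    "    return exp_scores / np.sum(exp_scores, axis=axis, keepdims=True)\n",
    "\n",
    "# ort_session.run returns NumPy arrays, so no torch conversion is needed\n",
    "logits_np = ort_session.run(['output'], ort_inputs)[0]\n",
    "probs_np = softmax_np(logits_np)\n",
    "probs_np.shape"
   ]
  },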
  {
   "cell_type": "markdown",
   "id": "2af6bda7",
   "metadata": {},
   "source": [
    "## Parse the prediction results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8d86ec66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the n results with the highest confidence\n",
    "n = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "6a53fbc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "top_n = torch.topk(pred_softmax, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "8433a2f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.return_types.topk(\n",
       "values=tensor([[9.9669e-01, 2.6005e-03, 3.0254e-04]]),\n",
       "indices=tensor([[954, 939, 941]]))"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "top_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "62b5fd54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted class IDs\n",
    "pred_ids = top_n.indices.numpy()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "13e876fd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([954, 939, 941])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "72536bc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction confidences\n",
    "confs = top_n.values.numpy()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "6731b721",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9.9668556e-01, 2.6005327e-03, 3.0253988e-04], dtype=float32)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confs"
   ]
  },
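  {
   "cell_type": "markdown",
   "id": "a96c3b18",
   "metadata": {},
   "source": [
    "The same top-n selection can also be sketched with NumPy alone, which may be useful on a target without PyTorch. This assumes `pred_softmax` and `n` from the cells above; `probs`, `top_ids`, and `top_confs` are illustrative names introduced here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e14d8f27",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch: NumPy-only counterpart of torch.topk for a single image\n",
    "probs = pred_softmax.numpy()[0]        # 1-D array of class probabilities\n",
    "top_ids = np.argsort(probs)[::-1][:n]  # class indices in descending order of probability\n",
    "top_confs = probs[top_ids]             # probabilities of those classes\n",
    "top_ids, top_confs"
   ]
  },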
  {
   "cell_type": "markdown",
   "id": "cfdaa026",
   "metadata": {},
   "source": [
    "## Load the mapping from class ID to class name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "310f07d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('imagenet_class_index.csv')\n",
    "idx_to_labels = {}\n",
    "for idx, row in df.iterrows():\n",
    "    idx_to_labels[row['ID']] = row['class']   # English class name\n",
    "#     idx_to_labels[row['ID']] = row['Chinese'] # Chinese class name"
   ]
  },
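  {
   "cell_type": "markdown",
   "id": "0b7a5c63",
   "metadata": {},
   "source": [
    "The loop above is fine for a table of 1000 rows. As a sketch of an equivalent form, the mapping can also be built directly from the two columns; `idx_to_labels_alt` is an illustrative name, and the `ID` and `class` column names are the ones used in the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2389e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: pair each class ID with its English name without an explicit loop\n",
    "idx_to_labels_alt = dict(zip(df['ID'], df['class']))\n",
    "\n",
    "# Both mappings should contain the same entries\n",
    "idx_to_labels_alt == idx_to_labels"
   ]
  },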
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "2a7af04a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx_to_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87a64f58",
   "metadata": {},
   "source": [
    "## Print the prediction results with English or Chinese class names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "9e7a92db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "banana               99.669\n",
      "zucchini             0.260\n",
      "acorn_squash         0.030\n"
     ]
    }
   ],
   "source": [
    "for i in range(n):\n",
    "    class_name = idx_to_labels[pred_ids[i]] # Look up the class name\n",
    "    confidence = confs[i] * 100             # Convert the confidence to a percentage\n",
    "    text = '{:<20} {:>.3f}'.format(class_name, confidence)\n",
    "    print(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd4edd1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
